Machine Learning Workflow: Lasso Regression

Biostat 274

Author

Dr. Jin Zhou @ UCLA

Published

December 23, 2025

Display system information for reproducibility.

sessionInfo()
R version 4.5.1 (2025-06-13)
Platform: aarch64-apple-darwin20
Running under: macOS Sequoia 15.7.3

Matrix products: default
BLAS:   /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRblas.0.dylib 
LAPACK: /Library/Frameworks/R.framework/Versions/4.5-arm64/Resources/lib/libRlapack.dylib;  LAPACK version 3.12.1

locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8

time zone: America/Los_Angeles
tzcode source: internal

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

loaded via a namespace (and not attached):
 [1] htmlwidgets_1.6.4 compiler_4.5.1    fastmap_1.2.0     cli_3.6.5        
 [5] tools_4.5.1       htmltools_0.5.8.1 yaml_2.3.10       rmarkdown_2.29   
 [9] knitr_1.50        jsonlite_2.0.0    xfun_0.53         digest_0.6.37    
[13] rlang_1.1.6       evaluate_1.0.5   
import IPython
print(IPython.sys_info())
{'commit_hash': 'd64897cf0',
 'commit_source': 'installation',
 'default_encoding': 'utf-8',
 'ipython_path': '/Users/jinjinzhou/.virtualenvs/r-reticulate/lib/python3.13/site-packages/IPython',
 'ipython_version': '9.0.1',
 'os_name': 'posix',
 'platform': 'macOS-15.7.3-arm64-arm-64bit-Mach-O',
 'sys_executable': '/Users/jinjinzhou/.virtualenvs/r-reticulate/bin/python',
 'sys_platform': 'darwin',
 'sys_version': '3.13.0 (main, Oct  7 2024, 05:02:14) [Clang 16.0.0 '
                '(clang-1600.0.26.4)]'}

1 What is tidymodels?

The tidymodels framework is a package ecosystem, in which all steps of the machine learning workflow are implemented through dedicated R packages. The consistency of these packages ensures their interoperability and ease of use. Most importantly, the framework should make your machine learning workflow easier to understand and faster to implement.

Below you can see the basic machine learning workflow and how it maps to existing packages from tidymodels:

  1. Preprocess: Transform and prepare data for modeling -> recipes

  2. Model: Select and specify a model for a specific problem -> parsnip

  3. Measure: Evaluate the performance of the model -> yardstick

  4. Sample: Split and sample input data to evaluate models -> rsample

  5. Tune: Adjust and improve the model on input data -> tune

2 Overview

We illustrate the typical machine learning workflow for regression problems using the Hitters data set from the R ISLR2 package. The steps are

  1. Initial splitting to test and non-test sets.

  2. Pre-processing of data (pipeline in Python, recipe in R).

  3. Choose a learner/method (lasso in this example).

  4. Tune the hyper-parameter(s) (\(\lambda\) in this example) using \(K\)-fold cross-validation (CV) on the non-test data.

  5. Choose the best model by CV and refit it on the whole non-test data.

  6. Final prediction on the test data.

These steps complete the process of training and evaluating one machine learning method (lasso in this case). We repeat the same process for other learners, e.g., random forest or neural network, using the same test/non-test and CV splits. The final report compares the learners based on CV and test errors.
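
The six steps above can be sketched end to end in plain Python. This is only an illustrative sketch on simulated data, not the tidymodels or scikit-learn code used later: the helper names (soft_threshold, fit_lasso_1d, cv_rmse) are hypothetical, and the learner is a lasso with a single predictor, for which the solution has a closed form.

```python
import math
import random

def soft_threshold(z, lam):
    """Soft-thresholding operator: shrink z toward zero by lam."""
    return math.copysign(max(abs(z) - lam, 0.0), z)

def fit_lasso_1d(x, y, lam):
    """One-predictor lasso with intercept (hypothetical helper):
    minimizes (1/(2n)) * sum((y - a - b*x)^2) + lam * |b|."""
    n = len(x)
    x_bar, y_bar = sum(x) / n, sum(y) / n
    sxy = sum((xi - x_bar) * (yi - y_bar) for xi, yi in zip(x, y)) / n
    sxx = sum((xi - x_bar) ** 2 for xi in x) / n
    b = soft_threshold(sxy, lam) / sxx
    return y_bar - b * x_bar, b

def rmse(model, x, y):
    a, b = model
    return math.sqrt(sum((yi - a - b * xi) ** 2 for xi, yi in zip(x, y)) / len(y))

# Simulated data (illustrative only): y = 3 + 2x + noise
rng = random.Random(425)
x = [rng.gauss(0, 1) for _ in range(200)]
y = [3 + 2 * xi + rng.gauss(0, 1) for xi in x]

# Step 1: initial split into non-test (75%) and test (25%) sets
idx = list(range(len(x)))
rng.shuffle(idx)
n_other = int(0.75 * len(idx))
other, test = idx[:n_other], idx[n_other:]

# Steps 2-4: tune lambda by K-fold CV on the non-test data only
# (centering happens inside fit_lasso_1d, so it is redone within each fold)
lambdas = [10 ** (-3 + 4 * k / 19) for k in range(20)]  # 1e-3 to 1e1
K = 5
folds = [other[k::K] for k in range(K)]

def cv_rmse(lam):
    scores = []
    for k in range(K):
        held_out = set(folds[k])
        train = [i for i in other if i not in held_out]
        model = fit_lasso_1d([x[i] for i in train], [y[i] for i in train], lam)
        scores.append(rmse(model, [x[i] for i in folds[k]], [y[i] for i in folds[k]]))
    return sum(scores) / K

cv_scores = [cv_rmse(lam) for lam in lambdas]

# Step 5: choose the best lambda and refit on all non-test data
best_lam = lambdas[cv_scores.index(min(cv_scores))]
final_model = fit_lasso_1d([x[i] for i in other], [y[i] for i in other], best_lam)

# Step 6: evaluate once on the untouched test data
test_rmse = rmse(final_model, [x[i] for i in test], [y[i] for i in test])
print(f"best lambda = {best_lam:.4g}, test RMSE = {test_rmse:.3f}")
```

The test set is touched exactly once, at the very end; every tuning decision is made from the non-test data.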

3 Hitters data

Documentation of the Hitters data is available in the ISLR2 package (see ?Hitters in R). The goal is to predict the salary (in thousands of dollars, at the opening of the 1987 season) of MLB players from their performance metrics in the 1986 season.

library(GGally)
library(ISLR2)
library(tidymodels)
Warning: package 'infer' was built under R version 4.5.2
Warning: package 'parsnip' was built under R version 4.5.2
library(tidyverse)

Hitters <- as_tibble(Hitters) %>% print(width = Inf)
# A tibble: 322 × 20
   AtBat  Hits HmRun  Runs   RBI Walks Years CAtBat CHits CHmRun CRuns  CRBI
   <int> <int> <int> <int> <int> <int> <int>  <int> <int>  <int> <int> <int>
 1   293    66     1    30    29    14     1    293    66      1    30    29
 2   315    81     7    24    38    39    14   3449   835     69   321   414
 3   479   130    18    66    72    76     3   1624   457     63   224   266
 4   496   141    20    65    78    37    11   5628  1575    225   828   838
 5   321    87    10    39    42    30     2    396   101     12    48    46
 6   594   169     4    74    51    35    11   4408  1133     19   501   336
 7   185    37     1    23     8    21     2    214    42      1    30     9
 8   298    73     0    24    24     7     3    509   108      0    41    37
 9   323    81     6    26    32     8     2    341    86      6    32    34
10   401    92    17    49    66    65    13   5206  1332    253   784   890
   CWalks League Division PutOuts Assists Errors Salary NewLeague
    <int> <fct>  <fct>      <int>   <int>  <int>  <dbl> <fct>    
 1     14 A      E            446      33     20   NA   A        
 2    375 N      W            632      43     10  475   N        
 3    263 A      W            880      82     14  480   A        
 4    354 N      E            200      11      3  500   N        
 5     33 N      E            805      40      4   91.5 N        
 6    194 A      W            282     421     25  750   A        
 7     24 N      E             76     127      7   70   A        
 8     12 A      W            121     283      9  100   A        
 9      8 N      W            143     290     19   75   N        
10    866 A      E              0       0      0 1100   A        
# ℹ 312 more rows
# Numerical summaries
summary(Hitters)
     AtBat            Hits         HmRun            Runs       
 Min.   : 16.0   Min.   :  1   Min.   : 0.00   Min.   :  0.00  
 1st Qu.:255.2   1st Qu.: 64   1st Qu.: 4.00   1st Qu.: 30.25  
 Median :379.5   Median : 96   Median : 8.00   Median : 48.00  
 Mean   :380.9   Mean   :101   Mean   :10.77   Mean   : 50.91  
 3rd Qu.:512.0   3rd Qu.:137   3rd Qu.:16.00   3rd Qu.: 69.00  
 Max.   :687.0   Max.   :238   Max.   :40.00   Max.   :130.00  
                                                               
      RBI             Walks            Years            CAtBat       
 Min.   :  0.00   Min.   :  0.00   Min.   : 1.000   Min.   :   19.0  
 1st Qu.: 28.00   1st Qu.: 22.00   1st Qu.: 4.000   1st Qu.:  816.8  
 Median : 44.00   Median : 35.00   Median : 6.000   Median : 1928.0  
 Mean   : 48.03   Mean   : 38.74   Mean   : 7.444   Mean   : 2648.7  
 3rd Qu.: 64.75   3rd Qu.: 53.00   3rd Qu.:11.000   3rd Qu.: 3924.2  
 Max.   :121.00   Max.   :105.00   Max.   :24.000   Max.   :14053.0  
                                                                     
     CHits            CHmRun           CRuns             CRBI        
 Min.   :   4.0   Min.   :  0.00   Min.   :   1.0   Min.   :   0.00  
 1st Qu.: 209.0   1st Qu.: 14.00   1st Qu.: 100.2   1st Qu.:  88.75  
 Median : 508.0   Median : 37.50   Median : 247.0   Median : 220.50  
 Mean   : 717.6   Mean   : 69.49   Mean   : 358.8   Mean   : 330.12  
 3rd Qu.:1059.2   3rd Qu.: 90.00   3rd Qu.: 526.2   3rd Qu.: 426.25  
 Max.   :4256.0   Max.   :548.00   Max.   :2165.0   Max.   :1659.00  
                                                                     
     CWalks        League  Division    PutOuts          Assists     
 Min.   :   0.00   A:175   E:157    Min.   :   0.0   Min.   :  0.0  
 1st Qu.:  67.25   N:147   W:165    1st Qu.: 109.2   1st Qu.:  7.0  
 Median : 170.50                    Median : 212.0   Median : 39.5  
 Mean   : 260.24                    Mean   : 288.9   Mean   :106.9  
 3rd Qu.: 339.25                    3rd Qu.: 325.0   3rd Qu.:166.0  
 Max.   :1566.00                    Max.   :1378.0   Max.   :492.0  
                                                                    
     Errors          Salary       NewLeague
 Min.   : 0.00   Min.   :  67.5   A:176    
 1st Qu.: 3.00   1st Qu.: 190.0   N:146    
 Median : 6.00   Median : 425.0            
 Mean   : 8.04   Mean   : 535.9            
 3rd Qu.:11.00   3rd Qu.: 750.0            
 Max.   :32.00   Max.   :2460.0            
                 NA's   :59                

The graphical summary takes longer to run, so it is suppressed here.

# Graphical summaries
ggpairs(
  data = Hitters, 
  mapping = aes(alpha = 0.25), 
  lower = list(continuous = "smooth")
  ) + 
  labs(title = "Hitters Data")

There are 59 NAs for the salary. Let’s drop those cases. We are left with 263 data points.

Hitters <- Hitters %>%
  drop_na()
dim(Hitters)
[1] 263  20
# Load the pandas library
import pandas as pd
# Load numpy for array manipulation
import numpy as np
# Load seaborn plotting library
import seaborn as sns
import matplotlib.pyplot as plt

# Set font sizes in plots
sns.set(font_scale = 2)
# Display all columns
pd.set_option('display.max_columns', None)

Hitters = pd.read_csv("../data/Hitters.csv")
Hitters
     AtBat  Hits  HmRun  Runs  RBI  Walks  Years  CAtBat  CHits  CHmRun  \
0      293    66      1    30   29     14      1     293     66       1   
1      315    81      7    24   38     39     14    3449    835      69   
2      479   130     18    66   72     76      3    1624    457      63   
3      496   141     20    65   78     37     11    5628   1575     225   
4      321    87     10    39   42     30      2     396    101      12   
..     ...   ...    ...   ...  ...    ...    ...     ...    ...     ...   
317    497   127      7    65   48     37      5    2703    806      32   
318    492   136      5    76   50     94     12    5511   1511      39   
319    475   126      3    61   43     52      6    1700    433       7   
320    573   144      9    85   60     78      8    3198    857      97   
321    631   170      9    77   44     31     11    4908   1457      30   

     CRuns  CRBI  CWalks League Division  PutOuts  Assists  Errors  Salary  \
0       30    29      14      A        E      446       33      20     NaN   
1      321   414     375      N        W      632       43      10   475.0   
2      224   266     263      A        W      880       82      14   480.0   
3      828   838     354      N        E      200       11       3   500.0   
4       48    46      33      N        E      805       40       4    91.5   
..     ...   ...     ...    ...      ...      ...      ...     ...     ...   
317    379   311     138      N        E      325        9       3   700.0   
318    897   451     875      A        E      313      381      20   875.0   
319    217    93     146      A        W       37      113       7   385.0   
320    470   420     332      A        E     1314      131      12   960.0   
321    775   357     249      A        W      408        4       3  1000.0   

    NewLeague  
0           A  
1           N  
2           A  
3           N  
4           N  
..        ...  
317         N  
318         A  
319         A  
320         A  
321         A  

[322 rows x 20 columns]
# Numerical summaries
Hitters.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 322 entries, 0 to 321
Data columns (total 20 columns):
 #   Column     Non-Null Count  Dtype  
---  ------     --------------  -----  
 0   AtBat      322 non-null    int64  
 1   Hits       322 non-null    int64  
 2   HmRun      322 non-null    int64  
 3   Runs       322 non-null    int64  
 4   RBI        322 non-null    int64  
 5   Walks      322 non-null    int64  
 6   Years      322 non-null    int64  
 7   CAtBat     322 non-null    int64  
 8   CHits      322 non-null    int64  
 9   CHmRun     322 non-null    int64  
 10  CRuns      322 non-null    int64  
 11  CRBI       322 non-null    int64  
 12  CWalks     322 non-null    int64  
 13  League     322 non-null    object 
 14  Division   322 non-null    object 
 15  PutOuts    322 non-null    int64  
 16  Assists    322 non-null    int64  
 17  Errors     322 non-null    int64  
 18  Salary     263 non-null    float64
 19  NewLeague  322 non-null    object 
dtypes: float64(1), int64(16), object(3)
memory usage: 50.4+ KB
Hitters.describe()
            AtBat        Hits       HmRun        Runs         RBI       Walks  \
count  322.000000  322.000000  322.000000  322.000000  322.000000  322.000000   
mean   380.928571  101.024845   10.770186   50.909938   48.027950   38.742236   
std    153.404981   46.454741    8.709037   26.024095   26.166895   21.639327   
min     16.000000    1.000000    0.000000    0.000000    0.000000    0.000000   
25%    255.250000   64.000000    4.000000   30.250000   28.000000   22.000000   
50%    379.500000   96.000000    8.000000   48.000000   44.000000   35.000000   
75%    512.000000  137.000000   16.000000   69.000000   64.750000   53.000000   
max    687.000000  238.000000   40.000000  130.000000  121.000000  105.000000   

            Years       CAtBat        CHits      CHmRun        CRuns  \
count  322.000000    322.00000   322.000000  322.000000   322.000000   
mean     7.444099   2648.68323   717.571429   69.490683   358.795031   
std      4.926087   2324.20587   654.472627   86.266061   334.105886   
min      1.000000     19.00000     4.000000    0.000000     1.000000   
25%      4.000000    816.75000   209.000000   14.000000   100.250000   
50%      6.000000   1928.00000   508.000000   37.500000   247.000000   
75%     11.000000   3924.25000  1059.250000   90.000000   526.250000   
max     24.000000  14053.00000  4256.000000  548.000000  2165.000000   

              CRBI       CWalks      PutOuts     Assists      Errors  \
count   322.000000   322.000000   322.000000  322.000000  322.000000   
mean    330.118012   260.239130   288.937888  106.913043    8.040373   
std     333.219617   267.058085   280.704614  136.854876    6.368359   
min       0.000000     0.000000     0.000000    0.000000    0.000000   
25%      88.750000    67.250000   109.250000    7.000000    3.000000   
50%     220.500000   170.500000   212.000000   39.500000    6.000000   
75%     426.250000   339.250000   325.000000  166.000000   11.000000   
max    1659.000000  1566.000000  1378.000000  492.000000   32.000000   

            Salary  
count   263.000000  
mean    535.925882  
std     451.118681  
min      67.500000  
25%     190.000000  
50%     425.000000  
75%     750.000000  
max    2460.000000  

The graphical summary takes longer to run, so it is suppressed here.

# Graphical summaries
sns.pairplot(data = Hitters)

There are 59 NAs for the salary. Let’s drop those cases. We are left with 263 data points.

Hitters.dropna(inplace = True)
Hitters.shape
(263, 20)

4 Initial split into test and non-test sets

# For reproducibility
set.seed(425)
data_split <- initial_split(
  Hitters, 
  # # stratify by percentiles
  # strata = "Salary", 
  prop = 0.75
  )

Hitters_other <- training(data_split)
dim(Hitters_other)
[1] 197  20
Hitters_test <- testing(data_split)
dim(Hitters_test)
[1] 66 20
from sklearn.model_selection import train_test_split

Hitters_other, Hitters_test = train_test_split(
  Hitters, 
  train_size = 0.75,
  random_state = 425, # seed
  )
Hitters_test.shape
(66, 20)
Hitters_other.shape
(197, 20)

Separate \(X\) and \(y\).

# Non-test X and y
X_other = Hitters_other.drop('Salary', axis = 1)
y_other = Hitters_other.Salary
# Test X and y
X_test = Hitters_test.drop('Salary', axis = 1)
y_test = Hitters_test.Salary
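
The row counts above (197 non-test, 66 test) follow directly from taking 75% of 263 rows and rounding down. A minimal sketch of a shuffle-based split, assuming a hypothetical helper shuffle_split that only mimics the counts (the selected rows will not match those chosen by rsample or scikit-learn, which use their own random number generators):

```python
import math
import random

def shuffle_split(n_samples, train_size, seed):
    """Hypothetical helper: shuffle row indices once, then cut at floor(train_size * n)."""
    n_train = math.floor(train_size * n_samples)
    idx = list(range(n_samples))
    random.Random(seed).shuffle(idx)
    return idx[:n_train], idx[n_train:]

other_idx, test_idx = shuffle_split(263, 0.75, seed=425)
print(len(other_idx), len(test_idx))  # 197 66
```

Because the R and Python splits are drawn by different generators, the two test sets contain different players even though both use seed 425, so the R and Python test errors are not expected to agree exactly.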

5 Preprocessing (Python) or recipe (R)

For regularization methods such as ridge and lasso, it is essential to center and scale predictors.

norm_recipe <- 
  recipe(
    Salary ~ ., 
    data = Hitters_other
  ) %>%
  # create traditional dummy variables
  step_dummy(all_nominal()) %>%
  # zero-variance filter
  step_zv(all_predictors()) %>% 
  # center and scale numeric data
  step_normalize(all_predictors()) # %>%
  # step_log(Salary, base = 10) %>%
  # estimate the means and standard deviations
  # prep(training = Hitters_other, retain = TRUE)
norm_recipe

Pre-processor for one-hot encoding of categorical variables and then standardizing all predictors.

from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.compose import make_column_transformer
from sklearn.pipeline import Pipeline

# OHE transformer for categorical variables
cattf = make_column_transformer(
  (OneHotEncoder(drop = 'first'), ['League', 'Division', 'NewLeague']),
  remainder = 'passthrough'
)
# Standardization transformer
scalar = StandardScaler()
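
To make the two preprocessing steps concrete, here is a plain-Python sketch of dummy coding and standardization, using a few League and Hits values from the rows printed earlier. The helper names fit_scaler and transform are hypothetical. The point to notice is that the mean and standard deviation are learned from the non-test rows only and then reused on the test rows.

```python
import math

def fit_scaler(column):
    """Learn the mean and population standard deviation (divisor n) of one column."""
    n = len(column)
    mean = sum(column) / n
    sd = math.sqrt(sum((v - mean) ** 2 for v in column) / n)
    return mean, sd

def transform(column, mean, sd):
    """Center and scale a column with previously learned statistics."""
    return [(v - mean) / sd for v in column]

# Dummy coding with the first level dropped: "A" -> 0, "N" -> 1
league_other = ["A", "N", "A", "N"]
league_dummy = [1.0 if v == "N" else 0.0 for v in league_other]

# A handful of Hits values, split by hand for illustration only
hits_other = [66.0, 81.0, 130.0, 141.0]
hits_test = [87.0, 169.0]

mean, sd = fit_scaler(hits_other)                 # statistics from non-test data only
hits_other_std = transform(hits_other, mean, sd)
hits_test_std = transform(hits_test, mean, sd)    # reuse the same mean and sd
print(league_dummy, hits_other_std, hits_test_std)
```

The divisor n matches what StandardScaler does. Inside a pipeline or workflow the same logic is applied within each CV fold, which keeps information from held-out rows out of the fitted statistics.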

6 Model

lasso_mod <- 
  # mixture = 0 (ridge), mixture = 1 (lasso)
  linear_reg(penalty = tune(), mixture = 1.0) %>% 
  set_engine("glmnet")
lasso_mod
Linear Regression Model Specification (regression)

Main Arguments:
  penalty = tune()
  mixture = 1

Computational engine: glmnet 
from sklearn.linear_model import Lasso

lasso = Lasso(max_iter = 10000)
lasso
Lasso(max_iter=10000)
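
For intuition about what the lasso fit does, the sketch below runs cyclic coordinate descent on the objective (1/(2n)) * sum of squared residuals + alpha * sum of |b_j|, which is the form scikit-learn documents for Lasso. It is a teaching sketch under the assumption of centered predictor columns, not the implementation used by scikit-learn or glmnet, and the function names are hypothetical.

```python
import math

def soft_threshold(z, t):
    """Shrink z toward zero by t; values within [-t, t] become exactly zero."""
    return math.copysign(max(abs(z) - t, 0.0), z)

def lasso_cd(X, y, alpha, n_iter=200):
    """Cyclic coordinate descent for
    (1/(2n)) * sum_i (y_i - b0 - x_i . b)^2 + alpha * sum_j |b_j|.
    X is a list of rows whose columns are assumed to have mean zero."""
    n, p = len(X), len(X[0])
    b0 = sum(y) / n                       # intercept when columns are centered
    b = [0.0] * p
    resid = [yi - b0 for yi in y]         # current residuals
    col_ss = [sum(X[i][j] ** 2 for i in range(n)) / n for j in range(p)]
    for _ in range(n_iter):
        for j in range(p):
            # covariance of column j with the residual that excludes column j
            rho = sum(X[i][j] * resid[i] for i in range(n)) / n + col_ss[j] * b[j]
            new_bj = soft_threshold(rho, alpha) / col_ss[j]
            delta = new_bj - b[j]
            if delta != 0.0:
                for i in range(n):
                    resid[i] -= delta * X[i][j]
                b[j] = new_bj
    return b0, b

# Tiny orthogonal design with y = 2*x1 + 1*x2 (made up for illustration)
X = [[1, 1], [1, -1], [-1, 1], [-1, -1]]
y = [3, 1, -1, -3]
for alpha in (0.0, 0.5, 1.0, 2.5):
    print(alpha, lasso_cd(X, y, alpha))
```

As alpha grows, each coefficient is shrunk by alpha and then set exactly to zero once alpha exceeds its unpenalized size; this is how the lasso performs variable selection.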

7 Pipeline (Python) or workflow (R)

Here we bundle the preprocessing step (Python) or recipe (R) and model.

lr_wf <- 
  workflow() %>%
  add_model(lasso_mod) %>%
  add_recipe(norm_recipe)
lr_wf
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: linear_reg()

── Preprocessor ────────────────────────────────────────────────────────────────
3 Recipe Steps

• step_dummy()
• step_zv()
• step_normalize()

── Model ───────────────────────────────────────────────────────────────────────
Linear Regression Model Specification (regression)

Main Arguments:
  penalty = tune()
  mixture = 1

Computational engine: glmnet 
pipe = Pipeline(steps = [
  ("cat_tf", cattf),
  ("std_tf", scalar), 
  ("model", lasso)
  ])
pipe
Pipeline(steps=[('cat_tf',
                 ColumnTransformer(remainder='passthrough',
                                   transformers=[('onehotencoder',
                                                  OneHotEncoder(drop='first'),
                                                  ['League', 'Division',
                                                   'NewLeague'])])),
                ('std_tf', StandardScaler()),
                ('model', Lasso(max_iter=10000))])

8 Tuning grid

Set up the tuning grid on the log scale: \(10^{-2}\) to \(10^{1.5}\) for the R penalty and \(10^{-3}\) to \(10^{2}\) for the Python alpha.

https://dials.tidymodels.org/reference/dials-package.html

lambda_grid <-
  grid_regular(penalty(range = c(-2, 1.5), trans = log10_trans()), levels = 100)
lambda_grid
# A tibble: 100 × 1
   penalty
     <dbl>
 1  0.01  
 2  0.0108
 3  0.0118
 4  0.0128
 5  0.0138
 6  0.0150
 7  0.0163
 8  0.0177
 9  0.0192
10  0.0208
# ℹ 90 more rows
# Tune hyper-parameter(s)
alphas = np.logspace(start = -3, stop = 2, base = 10, num = 100)
tuned_parameters = {"model__alpha": alphas}
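
Both grids are 100 values equally spaced on the log10 scale. The sketch below rebuilds them in plain Python with a hypothetical helper log_grid, which makes it easy to check that the values selected by CV later in these notes (penalty 4.13 in R, alpha 0.475 in Python) are simply grid points.

```python
def log_grid(low_exp, high_exp, num):
    """num values equally spaced on the log10 scale from 10**low_exp to 10**high_exp."""
    step = (high_exp - low_exp) / (num - 1)
    return [10 ** (low_exp + k * step) for k in range(num)]

r_grid = log_grid(-2, 1.5, 100)   # same range as penalty(range = c(-2, 1.5)), levels = 100
py_grid = log_grid(-3, 2, 100)    # same range as np.logspace(-3, 2, num = 100)

print(r_grid[0], r_grid[-1])      # endpoints of the R grid
print(py_grid[0], py_grid[-1])    # endpoints of the Python grid
print(r_grid[74], py_grid[53])    # 75th R value and 54th Python value
```

A log-spaced grid is the usual choice for a penalty parameter because its effect on the fit changes over orders of magnitude rather than linearly.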

9 Cross-validation (CV)

Set cross-validation partitions.

set.seed(250)
folds <- vfold_cv(Hitters_other, v = 10)
folds
#  10-fold cross-validation 
# A tibble: 10 × 2
   splits           id    
   <list>           <chr> 
 1 <split [177/20]> Fold01
 2 <split [177/20]> Fold02
 3 <split [177/20]> Fold03
 4 <split [177/20]> Fold04
 5 <split [177/20]> Fold05
 6 <split [177/20]> Fold06
 7 <split [177/20]> Fold07
 8 <split [178/19]> Fold08
 9 <split [178/19]> Fold09
10 <split [178/19]> Fold10

Fit cross-validation.

lasso_fit <- 
  lr_wf %>%
  tune_grid(
    resamples = folds,
    grid = lambda_grid)
lasso_fit
# Tuning results
# 10-fold cross-validation 
# A tibble: 10 × 4
   splits           id     .metrics           .notes          
   <list>           <chr>  <list>             <list>          
 1 <split [177/20]> Fold01 <tibble [200 × 5]> <tibble [0 × 4]>
 2 <split [177/20]> Fold02 <tibble [200 × 5]> <tibble [0 × 4]>
 3 <split [177/20]> Fold03 <tibble [200 × 5]> <tibble [0 × 4]>
 4 <split [177/20]> Fold04 <tibble [200 × 5]> <tibble [0 × 4]>
 5 <split [177/20]> Fold05 <tibble [200 × 5]> <tibble [0 × 4]>
 6 <split [177/20]> Fold06 <tibble [200 × 5]> <tibble [0 × 4]>
 7 <split [177/20]> Fold07 <tibble [200 × 5]> <tibble [0 × 4]>
 8 <split [178/19]> Fold08 <tibble [200 × 5]> <tibble [0 × 4]>
 9 <split [178/19]> Fold09 <tibble [200 × 5]> <tibble [0 × 4]>
10 <split [178/19]> Fold10 <tibble [200 × 5]> <tibble [0 × 4]>

Visualize CV criterion.

lasso_fit %>%
  collect_metrics() %>%
  print(width = Inf) %>%
  filter(.metric == "rmse") %>%
  ggplot(mapping = aes(x = penalty, y = mean)) + 
  geom_point() + 
  geom_line() + 
  labs(x = "Penalty", y = "CV RMSE") + 
  scale_x_log10(labels = scales::label_number())
# A tibble: 200 × 7
   penalty .metric .estimator    mean     n std_err .config          
     <dbl> <chr>   <chr>        <dbl> <int>   <dbl> <chr>            
 1  0.01   rmse    standard   341.       10 33.2    pre0_mod001_post0
 2  0.01   rsq     standard     0.472    10  0.0845 pre0_mod001_post0
 3  0.0108 rmse    standard   341.       10 33.2    pre0_mod002_post0
 4  0.0108 rsq     standard     0.472    10  0.0845 pre0_mod002_post0
 5  0.0118 rmse    standard   341.       10 33.2    pre0_mod003_post0
 6  0.0118 rsq     standard     0.472    10  0.0845 pre0_mod003_post0
 7  0.0128 rmse    standard   341.       10 33.2    pre0_mod004_post0
 8  0.0128 rsq     standard     0.472    10  0.0845 pre0_mod004_post0
 9  0.0138 rmse    standard   341.       10 33.2    pre0_mod005_post0
10  0.0138 rsq     standard     0.472    10  0.0845 pre0_mod005_post0
# ℹ 190 more rows

Show the top 5 models (\(\lambda\) values).

lasso_fit %>%
  show_best(metric = "rmse")
# A tibble: 5 × 7
  penalty .metric .estimator  mean     n std_err .config          
    <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>            
1    4.13 rmse    standard    338.    10    38.0 pre0_mod075_post0
2    4.48 rmse    standard    338.    10    38.2 pre0_mod076_post0
3    3.81 rmse    standard    338.    10    37.9 pre0_mod074_post0
4    4.86 rmse    standard    338.    10    38.4 pre0_mod077_post0
5    3.51 rmse    standard    338.    10    37.7 pre0_mod073_post0

Let’s select the best model.

best_lasso <- lasso_fit %>%
  select_best(metric = "rmse")
best_lasso
# A tibble: 1 × 2
  penalty .config          
    <dbl> <chr>            
1    4.13 pre0_mod075_post0

Set up CV partitions and CV criterion.

from sklearn.model_selection import GridSearchCV

# Set up CV
n_folds = 10
search = GridSearchCV(
  pipe, 
  tuned_parameters, 
  cv = n_folds, 
  scoring = "neg_root_mean_squared_error",
  # Refit the best model on the whole non-test data set
  refit = True 
  )
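
With an integer cv and a regression pipeline, GridSearchCV is documented to use consecutive, unshuffled K-fold splits. The following sketch, with a hypothetical helper kfold_indices, shows how 197 non-test rows are divided into 10 folds; it is not scikit-learn's own code.

```python
def kfold_indices(n_samples, n_splits):
    """Consecutive, unshuffled folds; the first n_samples % n_splits folds get one extra row."""
    base, extra = divmod(n_samples, n_splits)
    folds, start = [], 0
    for k in range(n_splits):
        size = base + (1 if k < extra else 0)
        test = list(range(start, start + size))
        train = list(range(0, start)) + list(range(start + size, n_samples))
        folds.append((train, test))
        start += size
    return folds

folds = kfold_indices(197, 10)
print([(len(train), len(test)) for train, test in folds])
```

Seven folds hold out 20 rows and three hold out 19, the same fold sizes reported by vfold_cv in the R output. Since the rows were already shuffled by the initial split, leaving the folds unshuffled is harmless here.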

Fit CV. This is typically the most time-consuming step.

# Fit CV
search.fit(X_other, y_other)
GridSearchCV(cv=10,
             estimator=Pipeline(steps=[('cat_tf',
                                        ColumnTransformer(remainder='passthrough',
                                                          transformers=[('onehotencoder',
                                                                         OneHotEncoder(drop='first'),
                                                                         ['League',
                                                                          'Division',
                                                                          'NewLeague'])])),
                                       ('std_tf', StandardScaler()),
                                       ('model', Lasso(max_iter=10000))]),
             param_grid={'model__alpha': array([1.00000000e-03, 1.12332403e-03, 1.26185688e-03, 1.41747416e-03,...
       6.89261210e+00, 7.74263683e+00, 8.69749003e+00, 9.77009957e+00,
       1.09749877e+01, 1.23284674e+01, 1.38488637e+01, 1.55567614e+01,
       1.74752840e+01, 1.96304065e+01, 2.20513074e+01, 2.47707636e+01,
       2.78255940e+01, 3.12571585e+01, 3.51119173e+01, 3.94420606e+01,
       4.43062146e+01, 4.97702356e+01, 5.59081018e+01, 6.28029144e+01,
       7.05480231e+01, 7.92482898e+01, 8.90215085e+01, 1.00000000e+02])},
             scoring='neg_root_mean_squared_error')

CV results.

cv_res = pd.DataFrame({
  "alpha": alphas,
  "rmse": -search.cv_results_["mean_test_score"]
  })

plt.figure()
sns.relplot(
  data = cv_res,
  x = "alpha",
  y = "rmse"
  ).set(
    xlabel = "alpha",
    ylabel = "CV RMSE",
    xscale = "log"
);
plt.show()

Best CV RMSE:

-search.best_score_
np.float64(327.5225980405363)
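
The minus sign is needed because scikit-learn scorers follow a "higher is better" convention, so RMSE is stored negated. A small sketch with made-up numbers, laid out like search.cv_results_, shows how the best setting is picked and the sign flipped back; the values are illustrative and are not the results of the fit above.

```python
# Made-up CV results in the layout of search.cv_results_ (illustrative values only)
cv_results = {
    "param_model__alpha": [0.01, 0.1, 1.0, 10.0],
    "mean_test_score": [-341.0, -339.5, -338.2, -345.7],  # negated CV RMSE
}

scores = cv_results["mean_test_score"]
best_index = max(range(len(scores)), key=scores.__getitem__)  # highest score wins
best_alpha = cv_results["param_model__alpha"][best_index]
best_rmse = -scores[best_index]                               # flip the sign back
print(best_alpha, best_rmse)  # 1.0 338.2
```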

10 Finalize our model

Now we are done tuning. Finally, let’s fit the final model to the whole non-test data and use the test data to estimate the performance we expect to see on new data.

# Final workflow
final_wf <- lr_wf %>%
  finalize_workflow(best_lasso)
final_wf
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: linear_reg()

── Preprocessor ────────────────────────────────────────────────────────────────
3 Recipe Steps

• step_dummy()
• step_zv()
• step_normalize()

── Model ───────────────────────────────────────────────────────────────────────
Linear Regression Model Specification (regression)

Main Arguments:
  penalty = 4.13201240011533
  mixture = 1

Computational engine: glmnet 
# Fit the whole training set, then predict the test cases
final_fit <- 
  final_wf %>%
  last_fit(data_split)
final_fit
# Resampling results
# Manual resampling 
# A tibble: 1 × 6
  splits           id               .metrics .notes   .predictions .workflow 
  <list>           <chr>            <list>   <list>   <list>       <list>    
1 <split [197/66]> train/test split <tibble> <tibble> <tibble>     <workflow>
# Test metrics
final_fit %>% collect_metrics()
# A tibble: 2 × 4
  .metric .estimator .estimate .config        
  <chr>   <chr>          <dbl> <chr>          
1 rmse    standard     319.    pre0_mod0_post0
2 rsq     standard       0.412 pre0_mod0_post0

Since we called GridSearchCV with refit = True, the best model fit on the whole non-test data is readily available.

search.best_estimator_
Pipeline(steps=[('cat_tf',
                 ColumnTransformer(remainder='passthrough',
                                   transformers=[('onehotencoder',
                                                  OneHotEncoder(drop='first'),
                                                  ['League', 'Division',
                                                   'NewLeague'])])),
                ('std_tf', StandardScaler()),
                ('model',
                 Lasso(alpha=np.float64(0.4750810162102798), max_iter=10000))])

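The final prediction error on the test set is computed below. Note that mean_squared_error returns the MSE, not the RMSE.

from sklearn.metrics import mean_squared_error

# This is the test MSE; take its square root to get the RMSE
mean_squared_error(y_test, search.best_estimator_.predict(X_test))
129252.73200779222

The value above is the test MSE. Its square root, about 359.5, is the test RMSE, which is on the same scale as the best CV RMSE of 327.5.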
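
The number printed by mean_squared_error is on the squared scale, so it cannot be compared directly with a CV RMSE. A plain-Python sketch of the two metrics, with hypothetical helper names and a tiny made-up example, followed by the square root of the value reported above:

```python
import math

def mse(y_true, y_pred):
    """Mean squared error."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)

def rmse(y_true, y_pred):
    """Root mean squared error: same units as the response."""
    return math.sqrt(mse(y_true, y_pred))

# Tiny made-up example
y_true = [100.0, 500.0, 750.0]
y_pred = [200.0, 400.0, 750.0]
print(mse(y_true, y_pred), rmse(y_true, y_pred))

# Square root of the test MSE reported above
print(round(math.sqrt(129252.73200779222), 1))  # 359.5
```

Recent scikit-learn releases also provide sklearn.metrics.root_mean_squared_error, which returns the RMSE directly; whether it is available depends on the installed version.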